from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.nn import ModuleList, Linear, BatchNorm1d
from torch_geometric.nn import GCNConv


def get_model(model_name, num_features, num_classes, args):
    """Construct the model identified by ``model_name``.

    Supported names are ``"linear"`` (``Lin``), ``"mlp"`` (a 2-layer ``MLP``)
    and ``"gcn"`` (a 2-layer ``GNN`` built from GCN convolutions).

    Args:
        model_name: Exact, case-sensitive model identifier.
        num_features: Size of each input feature vector.
        num_classes: Number of output classes.
        args: Object exposing ``hidden_dim`` and ``dropout`` attributes; it is
            only read for the ``"mlp"`` and ``"gcn"`` models.

    Returns:
        A freshly constructed ``torch.nn.Module``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "linear":
        # The linear baseline has no hidden layer, so ``args`` is not consulted.
        return Lin(num_features=num_features, num_classes=num_classes)

    if model_name == "mlp":
        return MLP(
            num_features=num_features,
            num_classes=num_classes,
            hidden_dim=args.hidden_dim,
            num_layers=2,
            dropout=args.dropout,
        )

    if model_name == "gcn":
        return GNN(
            num_features=num_features,
            num_classes=num_classes,
            num_layers=2,
            hidden_dim=args.hidden_dim,
            dropout=args.dropout,
            conv_type=model_name,
        )

    raise ValueError(f"Model {model_name} not supported")


class GNN(torch.nn.Module):
    """Stack of graph convolutions producing per-node log-probabilities.

    The network has exactly ``num_layers`` convolutions. Every convolution
    except the last is followed by an in-place ReLU and dropout. The final
    convolution maps to ``num_classes`` outputs and is followed by
    ``log_softmax`` over dim 1.
    """

    def __init__(self, num_features, num_classes, hidden_dim, num_layers=2, dropout=0, conv_type="gcn"):
        """Build the convolution stack.

        Args:
            num_features: Input feature size.
            num_classes: Output size of the final convolution.
            hidden_dim: Width of every intermediate convolution.
            num_layers: Total number of convolutions; must be >= 1. With a
                single layer the model maps ``num_features`` straight to
                ``num_classes``.
            dropout: Dropout probability applied after each hidden layer.
            conv_type: Convolution identifier forwarded to ``get_conv``. The
                default is lowercase ``"gcn"`` because ``get_conv`` compares
                against ``"gcn"``; the old default ``"GCN"`` made the
                default constructor raise.

        Raises:
            ValueError: If ``num_layers < 1``, or if ``get_conv`` rejects
                ``conv_type``.
        """
        super(GNN, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Layer widths, e.g. [in, hid, hid, out]. Consecutive pairs give each
        # conv's (input, output) size. num_layers == 1 yields a single
        # in -> out conv instead of silently building two layers.
        dims = [num_features] + [hidden_dim] * (num_layers - 1) + [num_classes]
        self.convs = ModuleList(
            [get_conv(conv_type, d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )
        self.num_layers = num_layers
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        """Return log-probabilities for each row of ``x``.

        ``x`` is presumably (num_nodes, num_features) and ``edge_index`` /
        ``edge_weight`` follow the convolution's own conventions; confirm
        against the conv implementation. ``edge_weight`` defaults to ``None``
        and is passed through unchanged to every convolution.
        """
        for conv in self.convs[:-1]:
            x = conv(x, edge_index, edge_weight).relu_()
            # Dropout only acts in training mode (self.training).
            x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.convs[-1](x, edge_index, edge_weight)

        return F.log_softmax(out, dim=1)


class MLP(torch.nn.Module):
    """Multi-layer perceptron classifier that ignores graph structure.

    Each hidden layer computes ``BatchNorm(ReLU(Linear(x)))`` followed by
    dropout. The final Linear layer maps to ``num_classes`` and is followed by
    ``log_softmax`` over dim 1.
    """

    def __init__(self, num_features, num_classes, hidden_dim, num_layers, dropout):
        """Build the Linear / BatchNorm stacks.

        Args:
            num_features: Input feature size.
            num_classes: Output size of the final Linear layer.
            hidden_dim: Width of every hidden layer.
            num_layers: Total number of Linear layers; must be >= 1. With a
                single layer there are no hidden layers and no BatchNorm.
            dropout: Dropout probability applied after each hidden layer.

        Raises:
            ValueError: If ``num_layers < 1``.
        """
        super(MLP, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.dropout = dropout

        # Consecutive widths give each Linear layer's (in, out) size.
        # Previously num_layers == 1 silently produced two Linear layers.
        dims = [num_features] + [hidden_dim] * (num_layers - 1) + [num_classes]
        self.lins = ModuleList(
            [Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )
        # One BatchNorm per hidden layer (every Linear except the last).
        self.bns = ModuleList([BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)])

    def forward(self, x, edge_index=None, edge_weight=None):
        """Return log-probabilities for each row of ``x``.

        ``edge_index`` and ``edge_weight`` are accepted only so the signature
        matches the graph models; they are unused here, so both are optional.
        """
        for lin, bn in zip(self.lins[:-1], self.bns):
            # BatchNorm is applied after the ReLU (kept as in the original design).
            x = bn(lin(x).relu_())
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)

        return F.log_softmax(x, dim=1)


class Lin(torch.nn.Module):
    """Single Linear layer baseline: ``log_softmax(Linear(x))`` over dim 1."""

    def __init__(self, num_features, num_classes):
        """Create one ``Linear(num_features, num_classes)`` layer."""
        super(Lin, self).__init__()
        self.lin = Linear(num_features, num_classes)

    def forward(self, x, edge_index=None, edge_weight=None):
        """Return log-probabilities for each row of ``x``.

        ``edge_index`` and ``edge_weight`` exist only for signature parity with
        the graph models. They are unused, so both default to ``None``.
        """
        return F.log_softmax(self.lin(x), dim=1)


def get_conv(conv_type, input_dim, output_dim):
    """Return a graph-convolution layer mapping ``input_dim`` -> ``output_dim``.

    Matching is case-insensitive, so both ``"gcn"`` and ``"GCN"`` (the latter
    is ``GNN``'s historical default) select ``GCNConv``.

    Raises:
        ValueError: If ``conv_type`` does not name a supported convolution.
            Non-string values are rejected the same way.
    """
    key = conv_type.lower() if isinstance(conv_type, str) else conv_type
    if key == "gcn":
        return GCNConv(input_dim, output_dim)
    raise ValueError(f"Convolution type {conv_type} not supported")
